{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "custom_training_walkthrough",
      "provenance": [],
      "private_outputs": true,
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.4"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "rwxGnsA92emp"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "colab_type": "code",
        "id": "CPII1rGR2rF9",
        "colab": {}
      },
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "JtEZ1pCPn--z"
      },
      "source": [
        "# カスタム訓練：ウォークスルー"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "GV1F7tVTN3Dn"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/customization/custom_training_walkthrough\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/ja/tutorials/customization/custom_training_walkthrough.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/ja/tutorials/customization/custom_training_walkthrough.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/ja/tutorials/customization/custom_training_walkthrough.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SQ8UtWX18HNR",
        "colab_type": "text"
      },
      "source": [
        "Note: これらのドキュメントは私たちTensorFlowコミュニティが翻訳したものです。コミュニティによる 翻訳は**ベストエフォート**であるため、この翻訳が正確であることや[英語の公式ドキュメント](https://www.tensorflow.org/?hl=en)の 最新の状態を反映したものであることを保証することはできません。 この翻訳の品質を向上させるためのご意見をお持ちの方は、GitHubリポジトリ[tensorflow/docs](https://github.com/tensorflow/docs)にプルリクエストをお送りください。 コミュニティによる翻訳やレビューに参加していただける方は、 [docs-ja@tensorflow.org メーリングリスト](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-ja)にご連絡ください。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "LDrzLFXE8T1l"
      },
      "source": [
        "このガイドでは機械学習を使ってアヤメの花を品種によって*分類*します。TensorFlowを使って下記のことを行います。\n",
        "\n",
        "1. モデルを構築し、\n",
        "2. このモデルをサンプルデータを使って訓練し、\n",
        "3. このモデルを使って未知のデータに対する予測を行う\n",
        "\n",
        "## TensorFlow プログラミング\n",
        "\n",
        "このガイドでは、次のような TensorFlow の高レベルの概念を用います。\n",
        "\n",
        "* TensorFlow の既定の [Eager Execution](https://www.tensorflow.org/guide/eager) 開発環境を用いて、\n",
        "* データを [Datasets API](https://www.tensorflow.org/guide/datasets) を使ってインポートし、\n",
        "* TensorFlow の [Keras API](https://keras.io/getting-started/sequential-model-guide/) を使ってモデルと層を構築する\n",
        "\n",
        "このチュートリアルは、多くの TensorFlow のプログラムと同様に下記のように構成されています。\n",
        "\n",
        "1. データセットのインポートとパース\n",
        "2. モデルのタイプの選択\n",
        "3. モデルの訓練\n",
        "4. モデルの有効性の評価\n",
        "5. 訓練済みモデルを使った予測"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "yNr7H-AIoLOR"
      },
      "source": [
        "## プログラムの設定"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "1J3AuPBT9gyR"
      },
      "source": [
        "### インポートの構成\n",
        "\n",
        "TensorFlow とその他必要な Python のモジュールをインポートします。TensorFlow は既定で [Eager Execution](https://www.tensorflow.org/guide/eager) を使って演算結果を即時に評価し、後で実行される [計算グラフ](https://www.tensorflow.org/guide/graphs) を構築する代わりに具体的な値を返します。REPL や `python` の対話的コンソールを使っている方にとっては、馴染みのあるものです。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "jElLULrDhQZR",
        "colab": {}
      },
      "source": [
        "from __future__ import absolute_import, division, print_function, unicode_literals\n",
        "\n",
        "import os\n",
        "import matplotlib.pyplot as plt"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "bfV2Dai0Ow2o",
        "colab": {}
      },
      "source": [
        "try:\n",
        "  # %tensorflow_version only exists in Colab.\n",
        "  %tensorflow_version 2.x\n",
        "except Exception:\n",
        "  pass\n",
        "import tensorflow as tf"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "g4Wzg69bnwK2",
        "colab": {}
      },
      "source": [
        "print(\"TensorFlow version: {}\".format(tf.__version__))\n",
        "print(\"Eager execution: {}\".format(tf.executing_eagerly()))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Zx7wc0LuuxaJ"
      },
      "source": [
        "## アヤメ分類問題\n",
        "\n",
        "あなたが植物学者で、発見したアヤメの花を分類する自動的な方法を探しているとします。機械学習には花を統計的に分類するためのアルゴリズムがたくさんあります。たとえば、洗練された機械学習プログラムを使うと、写真を元に花を分類することができます。私達の願望はもう少し控えめです。アヤメの花を [sepals（萼片）](https://en.wikipedia.org/wiki/Sepal) と [petals（花弁）](https://en.wikipedia.org/wiki/Petal) の長さと幅を使って分類することにしましょう。\n",
        "\n",
        "アヤメ属にはおよそ300の種がありますが、ここで扱うプログラムは次の3つの種類のみを分類します。\n",
        "\n",
        "* Iris setosa\n",
        "* Iris virginica\n",
        "* Iris versicolor\n",
        "\n",
        "<table>\n",
        "  <tr><td>\n",
        "    <img src=\"https://www.tensorflow.org/images/iris_three_species.jpg\"\n",
        "         alt=\"Petal geometry compared for three iris species: Iris setosa, Iris virginica, and Iris versicolor\">\n",
        "  </td></tr>\n",
        "  <tr><td align=\"center\">\n",
        "    <b>図1.</b> <a href=\"https://commons.wikimedia.org/w/index.php?curid=170298\">Iris setosa</a> (by <a href=\"https://commons.wikimedia.org/wiki/User:Radomil\">Radomil</a>, CC BY-SA 3.0), <a href=\"https://commons.wikimedia.org/w/index.php?curid=248095\">Iris versicolor</a>, (by <a href=\"https://commons.wikimedia.org/wiki/User:Dlanglois\">Dlanglois</a>, CC BY-SA 3.0), and <a href=\"https://www.flickr.com/photos/33397993@N05/3352169862\">Iris virginica</a> (by <a href=\"https://www.flickr.com/photos/33397993@N05\">Frank Mayfield</a>, CC BY-SA 2.0).<br/>&nbsp;\n",
        "  </td></tr>\n",
        "</table>\n",
        "\n",
        "幸いにして、萼片と花弁の計測値を使った [120のアヤメのデータセット](https://en.wikipedia.org/wiki/Iris_flower_data_set) を作ってくれた方がいます。これは初歩の機械学習による分類問題でよく使われる古典的なデータセットです。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "3Px6KAg0Jowz"
      },
      "source": [
        "## 訓練用データセットのインポートとパース\n",
        "\n",
        "データセットファイルをダウンロードし、この Python プログラムで使えるような構造に変換します。\n",
        "\n",
        "### データセットのダウンロード\n",
        "\n",
        "[tf.keras.utils.get_file](https://www.tensorflow.org/api_docs/python/tf/keras/utils/get_file) 関数を使って訓練用データセットをダウンロードします。この関数は、ダウンロードしたファイルのパスを返します。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "J6c7uEU9rjRM",
        "colab": {}
      },
      "source": [
        "train_dataset_url = \"https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv\"\n",
        "\n",
        "train_dataset_fp = tf.keras.utils.get_file(fname=os.path.basename(train_dataset_url),\n",
        "                                           origin=train_dataset_url)\n",
        "\n",
        "print(\"Local copy of the dataset file: {}\".format(train_dataset_fp))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "qnX1-aLors4S"
      },
      "source": [
        "### データの観察\n",
        "\n",
        "このデータセット `iris_training.csv` は、表形式のデータをコンマ区切り値（CSV）形式で格納したプレインテキストファイルです。`head -n5` コマンドを使って最初の5つのエントリを見てみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "FQvb_JYdrpPm",
        "colab": {}
      },
      "source": [
        "!head -n5 {train_dataset_fp}"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "kQhzD6P-uBoq"
      },
      "source": [
        "こうしてデータセットを見ることで次のことがわかります。\n",
        "\n",
        "1. 最初の行はデータセットに関する情報を含むヘッダである\n",
        "  * 全体で120のサンプルがある。各サンプルには4つの特徴量があり、3つのラベルの1つが当てはまる。\n",
        "2. 続く行はデータレコードで、1行が1つの [*サンプル(example)*](https://developers.google.com/machine-learning/glossary/#example) である。ここでは、\n",
        "  * 最初の4つのフィールドはサンプルの特徴を示す [*特徴量(features)*](https://developers.google.com/machine-learning/glossary/#feature) である。ここで各フィールドは花の計測値を表す浮動小数点値をもつ。\n",
        "  * 最後の列は予想したい [*ラベル(label)*](https://developers.google.com/machine-learning/glossary/#label) である。このデータセットでは、花の名前に対応する整数値、0、1、2 のいずれかである。\n",
        "\n",
        "コードを書いてみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "9Edhevw7exl6",
        "colab": {}
      },
      "source": [
        "# CSV ファイルの列の順序\n",
        "column_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species']\n",
        "\n",
        "feature_names = column_names[:-1]\n",
        "label_name = column_names[-1]\n",
        "\n",
        "print(\"Features: {}\".format(feature_names))\n",
        "print(\"Label: {}\".format(label_name))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "CCtwLoJhhDNc"
      },
      "source": [
        "ラベルはそれぞれ文字列の名前（たとえば、\"setosa\"）に関連付けられていますが、機械学習は大抵の場合、数値に依存しています。ラベルの数値は名前の表現にマップされます。たとえば次のようになります。\n",
        "\n",
        "* `0`: Iris setosa\n",
        "* `1`: Iris versicolor\n",
        "* `2`: Iris virginica\n",
        "\n",
        "特徴量とラベルについてもっと知りたい場合には、 [ML Terminology section of the Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/framing/ml-terminology) を参照してください。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "sVNlJlUOhkoX",
        "colab": {}
      },
      "source": [
        "class_names = ['Iris setosa', 'Iris versicolor', 'Iris virginica']"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "dqPkQExM2Pwt"
      },
      "source": [
        "### `tf.data.Dataset` の作成\n",
        "\n",
        "TensorFlow の [Dataset API](https://www.tensorflow.org/guide/datasets) はデータのモデルへのロードについて、多くの一般的なケースを扱います。この API は、データを読み込んで訓練に使われる形式へ変換するための高レベル API です。詳しくは、 [Datasets Quick Start guide](https://www.tensorflow.org/get_started/datasets_quickstart) を参照してください。\n",
        "\n",
        "今回のデータセットは CSV 形式のテキストファイルなので、データをパースして適切なフォーマットに変換するため、 [make_csv_dataset](https://www.tensorflow.org/api_docs/python/tf/data/experimental/make_csv_dataset) 関数を使います。この関数はモデルの訓練のためのデータを生成するので、既定の動作はデータをシャッフルし (`shuffle=True, shuffle_buffer_size=10000`) 、データセットを永遠に繰り返すこと (`num_epochs=None`) です。また、 [batch_size](https://developers.google.com/machine-learning/glossary/#batch_size) パラメータも設定します。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "WsxHnz1ebJ2S",
        "colab": {}
      },
      "source": [
        "batch_size = 32\n",
        "\n",
        "train_dataset = tf.data.experimental.make_csv_dataset(\n",
        "    train_dataset_fp,\n",
        "    batch_size,\n",
        "    column_names=column_names,\n",
        "    label_name=label_name,\n",
        "    num_epochs=1)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "gB_RSn62c-3G"
      },
      "source": [
        "`make_csv_dataset` 関数は、 `(features, label)` というペアの `tf.data.Dataset` を返します。ここで、 `features` は、 `{'feature_name': value}` というディクショナリです。\n",
        "\n",
        "この `Dataset` オブジェクトは反復可能オブジェクトです。特徴量のバッチを1つ見てみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "iDuG94H-C122",
        "colab": {}
      },
      "source": [
        "features, labels = next(iter(train_dataset))\n",
        "\n",
        "print(features)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "E63mArnQaAGz"
      },
      "source": [
        "「特徴量のようなもの」はグループ化、言い換えると *バッチ化* されることに注意してください。それぞれのサンプル行のフィールドは、対応する特徴量配列に追加されます。これらの特徴量配列に保存されるサンプルの数を設定するには、 `batch_size` を変更します。\n",
        "\n",
        "上記のバッチから特徴量のいくつかをプロットしてみると、いくつかのクラスタが見られます。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "me5Wn-9FcyyO",
        "colab": {}
      },
      "source": [
        "plt.scatter(features['petal_length'],\n",
        "            features['sepal_length'],\n",
        "            c=labels,\n",
        "            cmap='viridis')\n",
        "\n",
        "plt.xlabel(\"Petal length\")\n",
        "plt.ylabel(\"Sepal length\")\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "YlxpSyHlhT6M"
      },
      "source": [
        "モデル構築のステップを単純化するため、特徴量のディクショナリを `(batch_size, num_features)` という形状の単一の配列にパッケージし直す関数を作成します。\n",
        "\n",
        "この関数は、テンソルのリストから値を受け取り、指定された次元で結合されたテンソルを作成する [tf.stack](https://www.tensorflow.org/api_docs/python/tf/stack) メソッドを使っています。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "jm932WINcaGU",
        "colab": {}
      },
      "source": [
        "def pack_features_vector(features, labels):\n",
        "  \"\"\"特徴量を1つの配列にパックする\"\"\"\n",
        "  features = tf.stack(list(features.values()), axis=1)\n",
        "  return features, labels"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "V1Vuph_eDl8x"
      },
      "source": [
        "次に、 [tf.data.Dataset.map](https://www.tensorflow.org/api_docs/python/tf/data/dataset/map) メソッドを使って `(features,label)` ペアそれぞれの `features` を訓練用データセットにパックします。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "ZbDkzGZIkpXf",
        "colab": {}
      },
      "source": [
        "train_dataset = train_dataset.map(pack_features_vector)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "NLy0Q1xCldVO"
      },
      "source": [
        "この `Dataset` の特徴量要素は、`(batch_size, num_features)` の形状をした配列になっています。 最初のサンプルをいくつか見てみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "kex9ibEek6Tr",
        "colab": {}
      },
      "source": [
        "features, labels = next(iter(train_dataset))\n",
        "\n",
        "print(features[:5])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "LsaVrtNM3Tx5"
      },
      "source": [
        "## モデルのタイプの選択\n",
        "\n",
        "### なぜモデルか？\n",
        "\n",
        "[*モデル(model)*](https://developers.google.com/machine-learning/crash-course/glossary#model) は特徴量とラベルの間の関係です。アヤメの分類問題の場合、モデルは萼片と花弁の計測値と予測されるアヤメの種類の関係を定義します。単純なモデルであれば、2、3行の数式で記述できますが、複雑な機械学習モデルには多数のパラメータがあり、かんたんに要約できるものではありません。\n",
        "\n",
        "機械学習を*使わずに*これらの4つの特徴量とアヤメの品種の関係を決定することは可能でしょうか？言い換えると、従来のプログラミング技術（たとえば、たくさんの条件文）を使ってモデルを作ることは可能でしょうか？おそらく、十分に時間をかけてデータセットを分析して、萼片と花弁の計測値と特定の品種との関係を決定すれば可能でしょう。もっと複雑なデータセットの場合には、これは困難であり、おそらく不可能です。機械学習を使ったよいアプローチでは、*あなたに代わってモデルを決定*してくれます。適切なタイプの機械学習モデルに、十分な量の典型的なサンプルを投入すれば、プログラムがあなたのためにこの関係性をみつけ出してくれるでしょう。\n",
        "\n",
        "###  モデルの選択\n",
        "\n",
        "訓練すべきモデルの種類を決める必要があります。モデルにはたくさんの種類があり、よいものを選択するには経験が必要です。このチュートリアルでは、アヤメの分類問題を解くために、ニューラルネットワークを使用します。[*ニューラルネットワーク(Neural networks)*](https://developers.google.com/machine-learning/glossary/#neural_network) は、特徴量とラベルの間の複雑な関係をみつけることができます。ニューラルネットワークは、1つかそれ以上の [*隠れ層(hidden layers)*](https://developers.google.com/machine-learning/glossary/#hidden_layer) で構成された高度なグラフ構造です。それぞれの隠れ層には1つ以上の [*ニューロン(neurons)*](https://developers.google.com/machine-learning/glossary/#neuron) があります。ニューラルネットワークにはいくつかのカテゴリーがありますが、このプログラムでは、[*全結合ニューラルネットワーク(fully-connected neural network または dense neural network)*](https://developers.google.com/machine-learning/glossary/#fully_connected_layer) を使用します。全結合ニューラルネットワークでは、1つの層の中のニューロンすべてが、前の層の*すべての*ニューロンからの入力を受け取ります。例として、図2に入力層、2つの隠れ層、そして、出力層からなる密なニューラルネットワークを示します。\n",
        "\n",
        "<table>\n",
        "  <tr><td>\n",
        "    <img src=\"https://www.tensorflow.org/images/custom_estimators/full_network.png\"\n",
        "         alt=\"A diagram of the network architecture: Inputs, 2 hidden layers, and outputs\">\n",
        "  </td></tr>\n",
        "  <tr><td align=\"center\">\n",
        "    <b>図2.</b> 特徴量と隠れ層、予測をもつニューラルネットワーク<br/>&nbsp;\n",
        "  </td></tr>\n",
        "</table>\n",
        "\n",
        "図2のモデルが訓練されてラベルの付いていないサンプルを受け取ったとき、モデルは3つの予測値を返します。予測値はサンプルの花が与えられた3つの品種のそれぞれである可能性を示します。この予測は、 [*推論(inference)*](https://developers.google.com/machine-learning/crash-course/glossary#inference) とも呼ばれます。このサンプルでは、出力の予測値の合計は 1.0 になります。図2では、この予測値は *Iris setosa* が `0.02` 、 *Iris versicolor* が `0.95` 、 *Iris virginica* が `0.03` となっています。これは、モデルがこのラベルのない花を、95% の確率で *Iris versicolor* であると予測したことを意味します。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "W23DIMVPQEBt"
      },
      "source": [
        "### Keras を使ったモデル構築\n",
        "\n",
        "TensorFlow の [tf.keras](https://www.tensorflow.org/api_docs/python/tf/keras) API は、モデルと層を作成するためのおすすめの方法です。Keras がすべてを結びつけるという複雑さを引き受けてくれるため、モデルや実験の構築がかんたんになります。\n",
        "\n",
        "[tf.keras.Sequential](https://www.tensorflow.org/api_docs/python/tf/keras/Sequential) モデルは層が一列に積み上げられたものです。このクラスのコンストラクタは、レイヤーインスタンスのリストを引数として受け取ります。今回の場合、それぞれ 10 のノードをもつ 2つの [Dense](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense) レイヤーと、ラベルの予測値を表す3つのノードからなる出力レイヤーです。最初のレイヤーの `input_shape` パラメータがデータセットの特徴量の数に対応しており、これは必須パラメータです。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "2fZ6oL2ig3ZK",
        "colab": {}
      },
      "source": [
        "model = tf.keras.Sequential([\n",
        "  tf.keras.layers.Dense(10, activation=tf.nn.relu, input_shape=(4,)),  # input shape required\n",
        "  tf.keras.layers.Dense(10, activation=tf.nn.relu),\n",
        "  tf.keras.layers.Dense(3)\n",
        "])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "FHcbEzMpxbHL"
      },
      "source": [
        "[*活性化関数(activation function)*](https://developers.google.com/machine-learning/crash-course/glossary#activation_function)  は、そのレイヤーの各ノードの出力の形を決定します。この関数の非線形性は重要であり、それがなければモデルは 1層しかないものと等価になってしまいます。[利用可能な活性化関数](https://www.tensorflow.org/api_docs/python/tf/keras/activations) はたくさんありますが、隠れ層では [ReLU](https://developers.google.com/machine-learning/crash-course/glossary#ReLU) が一般的です。\n",
        "\n",
        "理想的な隠れ層の数やニューロンの数は問題やデータセットによって異なります。機械学習のさまざまな側面と同様に、ニューラルネットワークの最良の形を選択するには、知識と経験の両方が必要です。経験則から、一般的には隠れ層やニューロンの数を増やすとより強力なモデルを作ることができますが、効果的に訓練を行うためにより多くのデータを必要とします。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "2wFKnhWCpDSS"
      },
      "source": [
        "### モデルの使用\n",
        "\n",
        "それでは、このモデルが特徴量のバッチに対して何を行うかを見てみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "xe6SQ5NrpB-I",
        "colab": {}
      },
      "source": [
        "predictions = model(features)\n",
        "predictions[:5]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "wxyXOhwVr5S3"
      },
      "source": [
        "ご覧のように、サンプルのそれぞれは、各クラスの [ロジット(logit)](https://developers.google.com/machine-learning/crash-course/glossary#logits) 値を返します。\n",
        "\n",
        "これらのロジット値を各クラスの確率に変換するためには、 [softmax](https://developers.google.com/machine-learning/crash-course/glossary#softmax) 関数を使用します。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "_tRwHZmTNTX2",
        "colab": {}
      },
      "source": [
        "tf.nn.softmax(predictions[:5])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "uRZmchElo481"
      },
      "source": [
        "クラス間で `tf.argmax` を取ると、予測されたクラスのインデックスが得られます。しかし、このモデルはまだ訓練されていないので、予測はよいものではありません。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "-Jzm_GoErz8B",
        "colab": {}
      },
      "source": [
        "print(\"Prediction: {}\".format(tf.argmax(predictions, axis=1)))\n",
        "print(\"    Labels: {}\".format(labels))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Vzq2E5J2QMtw"
      },
      "source": [
        "## モデルの訓練\n",
        "\n",
        "[*訓練(Training)*](https://developers.google.com/machine-learning/crash-course/glossary#training) は、機械学習において、モデルが徐々に最適化されていく、あるいはモデルがデータセットを*学習する*段階です。目的は、見たことのないデータについて予測を行うため、訓練用データセットの構造を十分に学習することです。訓練用データセットを学習*しすぎる*と、予測は見たことのあるデータに対してしか有効ではなく、一般化できません。この問題は [*過学習(overfitting)*](https://developers.google.com/machine-learning/crash-course/glossary#overfitting) と呼ばれ、問題の解き方を理解するのではなく答えを丸暗記するようなものです。\n",
        "\n",
        "アヤメの分類問題は、 [*教師あり学習(supervised machine learning)*](https://developers.google.com/machine-learning/glossary/#supervised_machine_learning) の1種で、モデルはラベルの付いたサンプルを学習します。[*教師なし学習(unsupervised machine learning)*](https://developers.google.com/machine-learning/glossary/#unsupervised_machine_learning) では、サンプルにラベルはありません。そのかわり、一般的にはモデルが特徴量からパターンを発見します。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "RaKp8aEjKX6B"
      },
      "source": [
        "### 損失関数と勾配関数の定義\n",
        "\n",
        "訓練段階と評価段階では、モデルの [*損失(loss)*](https://developers.google.com/machine-learning/crash-course/glossary#loss) を計算する必要があります。損失とは、モデルの予測が望ましいラベルからどれくらい離れているかを測定するものです。言い換えると、どれくらいモデルの性能が悪いかを示します。この値を最小化、または最適化したいのです。\n",
        "\n",
        "今回のモデルでは、 `tf.keras.losses.SparseCategoricalCrossentropy` 関数を使って損失を計算します。この関数は、モデルのクラスごとの予測確率とラベルの取るべき値を使って、サンプル間の平均損失を返します。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "QOsi6b-1CXIn",
        "colab": {}
      },
      "source": [
        "loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "tMAT4DcMPwI-",
        "colab": {}
      },
      "source": [
        "def loss(model, x, y):\n",
        "  y_ = model(x)\n",
        "\n",
        "  return loss_object(y_true=y, y_pred=y_)\n",
        "\n",
        "\n",
        "l = loss(model, features, labels)\n",
        "print(\"Loss test: {}\".format(l))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "3IcPqA24QM6B"
      },
      "source": [
        "[tf.GradientTape](https://www.tensorflow.org/api_docs/python/tf/GradientTape) コンテキストを使って、モデルを最適化する際に使われる [*勾配(gradients)*](https://developers.google.com/machine-learning/crash-course/glossary#gradient) を計算しましょう。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "x57HcKWhKkei",
        "colab": {}
      },
      "source": [
        "def grad(model, inputs, targets):\n",
        "  with tf.GradientTape() as tape:\n",
        "    loss_value = loss(model, inputs, targets)\n",
        "  return loss_value, tape.gradient(loss_value, model.trainable_variables)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lOxFimtlKruu"
      },
      "source": [
        "### オプティマイザの作成\n",
        "\n",
        "[*オプティマイザ(optimizer)*](https://developers.google.com/machine-learning/crash-course/glossary#optimizer) は、`loss` 関数を最小化するため、計算された勾配をモデルの変数に適用します。損失関数は、曲面として考えることができ（図3 参照）、歩き回ることによって最小となる点をみつけたいのです。勾配は、一番急な上りの方向を示すため、逆方向に進んで丘を下ることになります。バッチごとに損失と勾配を繰り返し計算することにより、訓練中のモデルを調節します。モデルは徐々に、損失を最小化する重みとバイアスの最適な組み合わせをみつけます。損失が小さいほど、モデルの予測がよくなります。\n",
        "\n",
        "<table>\n",
        "  <tr><td>\n",
        "    <img src=\"https://cs231n.github.io/assets/nn3/opt1.gif\" width=\"70%\"\n",
        "         alt=\"Optimization algorithms visualized over time in 3D space.\">\n",
        "  </td></tr>\n",
        "  <tr><td align=\"center\">\n",
        "    <b>図3.</b> 3次元空間における最適化アルゴリズムの時系列可視化。<br/>(Source: <a href=\"http://cs231n.github.io/neural-networks-3/\">Stanford class CS231n</a>, MIT License, Image credit: <a href=\"https://twitter.com/alecrad\">Alec Radford</a>)\n",
        "  </td></tr>\n",
        "</table>\n",
        "\n",
        "TensorFlow には、訓練に使える [最適化アルゴリズム(optimization algorithms)](https://www.tensorflow.org/api_guides/python/train) がたくさんあります。このモデルでは、[*確率的勾配降下法(stochastic gradient descent)*](https://developers.google.com/machine-learning/crash-course/glossary#gradient_descent) (SGD) アルゴリズムを実装した [tf.train.GradientDescentOptimizer](https://www.tensorflow.org/api_docs/python/tf/train/GradientDescentOptimizer) を使用します。`learning_rate` （学習率）は、イテレーションごとに丘を下る際のステップのサイズを設定します。通常これはよりよい結果を得るために調整する *ハイパーパラメータ(hyperparameter)* です。 "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "XkUd6UiZa_dF"
      },
      "source": [
        "オプティマイザーを設定しましょう。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "8xxi2NNGKwG_",
        "colab": {}
      },
      "source": [
        "optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "pJVRZ0hP52ZB"
      },
      "source": [
        "これを使って、最適化を１ステップ分計算してみます。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "rxRNTFVe56RG",
        "colab": {}
      },
      "source": [
        "loss_value, grads = grad(model, features, labels)\n",
        "\n",
        "print(\"Step: {}, Initial Loss: {}\".format(optimizer.iterations.numpy(),\n",
        "                                          loss_value.numpy()))\n",
        "\n",
        "optimizer.apply_gradients(zip(grads, model.trainable_variables))\n",
        "\n",
        "print(\"Step: {},         Loss: {}\".format(optimizer.iterations.numpy(),\n",
        "                                          loss(model, features, labels).numpy()))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7Y2VSELvwAvW"
      },
      "source": [
        "### 訓練ループ\n",
        "\n",
        "すべての部品が揃ったので、モデルの訓練ができるようになりました。訓練ループは、モデルにデータセットのサンプルを供給し、モデルがよりよい予測を行えるようにします。下記のコードブロックは、この訓練のステップを構成します。\n",
        "\n",
        "1. *epoch（エポック）* をひとつずつ繰り返します。エポックとは、データセットをひととおり処理するということです。\n",
        "2. エポック内では、訓練用の `Dataset（データセット）` のサンプルひとつずつから、その *features（特徴量）* (`x`) と *label（ラベル）* (`y`) を取り出して繰り返し処理します。\n",
        "3. サンプルの特徴量を使って予測を行い、ラベルと比較します。予測の不正確度を測定し、それを使ってモデルの損失と勾配を計算します。\n",
        "4. `optimizer` を使って、モデルの変数を更新します。\n",
        "5. 可視化のためにいくつかの統計量を記録します。\n",
        "6. これをエポックごとに繰り返します。\n",
        "\n",
        "`num_epochs` 変数は、データセットのコレクションを何回繰り返すかという数字です。直感には反しますが、モデルを長く訓練すれば必ずよいモデルが得られるというわけではありません。`num_epochs` は、[*ハイパーパラメータ(hyperparameter)*](https://developers.google.com/machine-learning/glossary/#hyperparameter) であり、チューニングできます。適切な数を選択するには、経験と実験の両方が必要です。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "AIgulGRUhpto",
        "colab": {}
      },
      "source": [
        "## Note: このセルを再実行すると同じモデル変数が使われます\n",
        "\n",
        "# 結果をグラフ化のために保存\n",
        "train_loss_results = []\n",
        "train_accuracy_results = []\n",
        "\n",
        "num_epochs = 201\n",
        "\n",
        "for epoch in range(num_epochs):\n",
        "  epoch_loss_avg = tf.keras.metrics.Mean()\n",
        "  epoch_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()\n",
        "\n",
        "  # 訓練ループ - 32個ずつのバッチを使用\n",
        "  for x, y in train_dataset:\n",
        "    # モデルの最適化\n",
        "    loss_value, grads = grad(model, x, y)\n",
        "    optimizer.apply_gradients(zip(grads, model.trainable_variables))\n",
        "\n",
        "    # 進捗の記録\n",
        "    epoch_loss_avg(loss_value)  # Add current batch loss 現在のバッチの損失を加算\n",
        "    # 予測ラベルと実際のラベルを比較\n",
        "    epoch_accuracy(y, model(x))\n",
        "    \n",
        "  # エポックの終わり\n",
        "  train_loss_results.append(epoch_loss_avg.result())\n",
        "  train_accuracy_results.append(epoch_accuracy.result())\n",
        "\n",
        "  if epoch % 50 == 0:\n",
        "    print(\"Epoch {:03d}: Loss: {:.3f}, Accuracy: {:.3%}\".format(epoch,\n",
        "                                                                epoch_loss_avg.result(),\n",
        "                                                                epoch_accuracy.result()))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "2FQHVUnm_rjw"
      },
      "source": [
        "### 時間の経過に対する損失関数の可視化"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "j3wdbmtLVTyr"
      },
      "source": [
        "モデルの訓練の進み方をプリントするのも役立ちますが、普通はこの進捗を見るほうが*もっと*役に立ちます。[TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard) は TensorFlow に付属する優れた可視化ツールですが、\n",
        " `matplotlib` モジュールを使って基本的なグラフを描画することもできます。\n",
        " \n",
        "これらのグラフを解釈するにはある程度の経験が必要ですが、*損失*が低下して、*正解率*が上昇するのをぜひ見たいと思うでしょう。 "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "agjvNd2iUGFn",
        "colab": {}
      },
      "source": [
        "fig, axes = plt.subplots(2, sharex=True, figsize=(12, 8))\n",
        "fig.suptitle('Training Metrics')\n",
        "\n",
        "axes[0].set_ylabel(\"Loss\", fontsize=14)\n",
        "axes[0].plot(train_loss_results)\n",
        "\n",
        "axes[1].set_ylabel(\"Accuracy\", fontsize=14)\n",
        "axes[1].set_xlabel(\"Epoch\", fontsize=14)\n",
        "axes[1].plot(train_accuracy_results)\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Zg8GoMZhLpGH"
      },
      "source": [
        "## モデルの有効性評価\n",
        "\n",
        "モデルの訓練が終わったら、その性能の指標を得ることができます。\n",
        "\n",
        "*評価*とは、モデルの予測がどれだけ効果的であるかどうかを決定することを言います。アヤメの分類でモデルの有効性を見定めるには、モデルに萼片と花弁の測定値をいくつか与え、どのアヤメの品種であるかを予測させます。そして、モデルの予測と実際のラベルを比較します。たとえば、あるモデルが入力サンプルの半分で正解の品種をえらんだとすると、[*正解率(accuracy)*](https://developers.google.com/machine-learning/glossary/#accuracy) は `0.5` ということになります。図４で示すモデルはもう少し効果的であり、5つの予測中4つで正解し、正解率は 80% です。\n",
        "\n",
        "<table cellpadding=\"8\" border=\"0\">\n",
        "  <colgroup>\n",
        "    <col span=\"4\" >\n",
        "    <col span=\"1\" bgcolor=\"lightblue\">\n",
        "    <col span=\"1\" bgcolor=\"lightgreen\">\n",
        "  </colgroup>\n",
        "  <tr bgcolor=\"lightgray\">\n",
        "    <th colspan=\"4\">サンプルの特徴量</th>\n",
        "    <th colspan=\"1\">ラベル</th>\n",
        "    <th colspan=\"1\" >モデルの予測値</th>\n",
        "  </tr>\n",
        "  <tr>\n",
        "    <td>5.9</td><td>3.0</td><td>4.3</td><td>1.5</td><td align=\"center\">1</td><td align=\"center\">1</td>\n",
        "  </tr>\n",
        "  <tr>\n",
        "    <td>6.9</td><td>3.1</td><td>5.4</td><td>2.1</td><td align=\"center\">2</td><td align=\"center\">2</td>\n",
        "  </tr>\n",
        "  <tr>\n",
        "    <td>5.1</td><td>3.3</td><td>1.7</td><td>0.5</td><td align=\"center\">0</td><td align=\"center\">0</td>\n",
        "  </tr>\n",
        "  <tr>\n",
        "    <td>6.0</td> <td>3.4</td> <td>4.5</td> <td>1.6</td> <td align=\"center\">1</td><td align=\"center\" bgcolor=\"red\">2</td>\n",
        "  </tr>\n",
        "  <tr>\n",
        "    <td>5.5</td><td>2.5</td><td>4.0</td><td>1.3</td><td align=\"center\">1</td><td align=\"center\">1</td>\n",
        "  </tr>\n",
        "  <tr><td align=\"center\" colspan=\"6\">\n",
        "    <b>図4.</b> 正解率 80％ のアヤメ分類器<br/>&nbsp;\n",
        "  </td></tr>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "z-EvK7hGL0d8"
      },
      "source": [
        "### テスト用データセットの設定\n",
        "\n",
        "モデルの評価はモデルの訓練と同様です。もっとも大きな違いは、サンプルが訓練用データセットではなく[*テスト用データセット(test set)*](https://developers.google.com/machine-learning/crash-course/glossary#test_set) からのものであるという点です。モデルの有効性を正しく評価するには、モデルの評価に使うサンプルは訓練用データセットのものとは違うものでなければなりません。\n",
        "\n",
        "テスト用 `Dataset` の設定は、訓練用 `Dataset` の設定と同様です。CSV ファイルをダウンロードし、値をパースしてからすこしシャッフルを行います。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Ps3_9dJ3Lodk",
        "colab": {}
      },
      "source": [
        "test_url = \"https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv\"\n",
        "\n",
        "test_fp = tf.keras.utils.get_file(fname=os.path.basename(test_url),\n",
        "                                  origin=test_url)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "SRMWCu30bnxH",
        "colab": {}
      },
      "source": [
        "test_dataset = tf.data.experimental.make_csv_dataset(\n",
        "    test_fp,\n",
        "    batch_size,\n",
        "    column_names=column_names,\n",
        "    label_name='species',\n",
        "    num_epochs=1,\n",
        "    shuffle=False)\n",
        "\n",
        "test_dataset = test_dataset.map(pack_features_vector)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "HFuOKXJdMAdm"
      },
      "source": [
        "### テスト用データセットでのモデルの評価\n",
        "\n",
        "訓練段階とは異なり、モデルはテスト用データの 1[エポック(epoch)](https://developers.google.com/machine-learning/glossary/#epoch) だけで行います。下記のコードセルでは、テスト用データセットのサンプルをひとつずつ処理し、モデルの予測と実際のラベルを比較します。この結果を使ってテスト用データセット全体でのモデルの正解率を計算します。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Tw03-MK1cYId",
        "colab": {}
      },
      "source": [
        "test_accuracy = tf.keras.metrics.Accuracy()\n",
        "\n",
        "for (x, y) in test_dataset:\n",
        "  logits = model(x)\n",
        "  prediction = tf.argmax(logits, axis=1, output_type=tf.int32)\n",
        "  test_accuracy(prediction, y)\n",
        "\n",
        "print(\"Test set accuracy: {:.3%}\".format(test_accuracy.result()))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "HcKEZMtCOeK-"
      },
      "source": [
        "たとえば最後のバッチを見ると、モデルはだいたい正確です。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "uNwt2eMeOane",
        "colab": {}
      },
      "source": [
        "tf.stack([y,prediction],axis=1)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7Li2r1tYvW7S"
      },
      "source": [
        "## 訓練済みモデルを使った予測\n",
        "\n",
        "モデルを訓練し、それが、完全ではないにせよアヤメの品種の分類でよい性能を示すことを「証明」しました。次は、この訓練済みモデルを使って [ラベルなしサンプル(unlabeled examples)](https://developers.google.com/machine-learning/glossary/#unlabeled_example) つまり、特徴量だけでラベルのないサンプルについて予測を行ってみましょう。\n",
        "\n",
        "現実世界では、ラベルなしサンプルはアプリや CSV ファイル、データフィードなどさまざまな異なるソースからやってきます。ここでは、ラベルを予測するため、3つのラベルなしサンプルを手動で与えることにします。ラベルの番号は、下記のように名前を表していることを思い出してください。\n",
        "\n",
        "* `0`: Iris setosa\n",
        "* `1`: Iris versicolor\n",
        "* `2`: Iris virginica"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "kesTS5Lzv-M2",
        "colab": {}
      },
      "source": [
        "predict_dataset = tf.convert_to_tensor([\n",
        "    [5.1, 3.3, 1.7, 0.5,],\n",
        "    [5.9, 3.0, 4.2, 1.5,],\n",
        "    [6.9, 3.1, 5.4, 2.1]\n",
        "])\n",
        "\n",
        "predictions = model(predict_dataset)\n",
        "\n",
        "for i, logits in enumerate(predictions):\n",
        "  class_idx = tf.argmax(logits).numpy()\n",
        "  p = tf.nn.softmax(logits)[class_idx]\n",
        "  name = class_names[class_idx]\n",
        "  print(\"Example {} prediction: {} ({:4.1f}%)\".format(i, name, 100*p))"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}